{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1192e1d8-9f6b-47f2-804c-820d3773dd6d",
   "metadata": {},
   "source": [
    "# Phase gradient autofocus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b06324e-7f14-4371-8f67-18afa9397a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchbp\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import get_window\n",
    "import numpy as np\n",
    "\n",
    "# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Device:\", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c6fe062-389d-4be3-b537-753b46d88f56",
   "metadata": {},
   "source": [
    "Generate synthetic radar data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4be6f29d-2c72-443a-99fd-46590d98c54f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image size\n",
    "nr = 100  # Range points\n",
    "ntheta = 128  # Azimuth points\n",
    "\n",
    "# Measurement and waveform parameters\n",
    "nsweeps = 128  # Number of measurements\n",
    "fc = 6e9  # RF center frequency\n",
    "bw = 100e6  # RF bandwidth\n",
    "tsweep = 100e-6  # Sweep length\n",
    "fs = 1e6  # Sampling frequency\n",
    "nsamples = int(fs * tsweep)  # Time domain samples per sweep\n",
    "\n",
    "# Imaging grid definition in polar format.\n",
    "# \"theta\" is the sine of the azimuth angle, so 0.2 corresponds to about 11.5 degrees.\n",
    "grid_polar = {\n",
    "    \"r\": (90, 110),\n",
    "    \"theta\": (-0.2, 0.2),\n",
    "    \"nr\": nr,\n",
    "    \"ntheta\": ntheta,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d107200-f4bd-4649-9999-5f756fd92fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point targets: one row of three coordinates per target, each with the same RCS.\n",
    "target_pos = torch.tensor(\n",
    "    [[100, 0, 0], [105, 10, 0], [97, -5, 0], [102, -10, 0], [95, 5, 0]],\n",
    "    dtype=torch.float32,\n",
    "    device=device,\n",
    ")\n",
    "target_rcs = torch.tensor([1, 1, 1, 1, 1], dtype=torch.float32, device=device)\n",
    "\n",
    "# Radar position for each sweep. Only the second coordinate changes between sweeps,\n",
    "# in steps of roughly a quarter of the RF wavelength (3e8 / fc).\n",
    "# The linspace is created directly on `device` with an explicit dtype, like the other tensors here,\n",
    "# instead of being built on the CPU and copied implicitly during the assignment.\n",
    "pos = torch.zeros([nsweeps, 3], dtype=torch.float32, device=device)\n",
    "pos[:,1] = torch.linspace(-nsweeps/2, nsweeps/2, nsweeps, dtype=torch.float32, device=device) * 0.25 * 3e8 / fc\n",
    "\n",
    "# Not used in this case\n",
    "vel = torch.zeros((nsweeps, 3), dtype=torch.float32, device=device)\n",
    "att = torch.zeros((nsweeps, 3), dtype=torch.float32, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9bcbfd-ff44-4d79-8a2e-fc7c8d480ea9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oversampling input data decreases interpolation errors\n",
    "oversample = 3\n",
    "\n",
    "with torch.no_grad():\n",
    "    data = torchbp.util.generate_fmcw_data(target_pos, target_rcs, pos, fc, bw, tsweep, fs)\n",
    "    # Taylor windows: wr varies along the last (range) axis, wa along the first (azimuth) axis\n",
    "    wr = torch.tensor(get_window((\"taylor\", 3, 30), data.shape[-1])[None,:], dtype=torch.float32, device=device)\n",
    "    wa = torch.tensor(get_window((\"taylor\", 3, 30), data.shape[0])[:,None], dtype=torch.float32, device=device)\n",
    "    # Apply both windows, then FFT along the range axis with length nsamples * oversample\n",
    "    data = torch.fft.fft(data * wa * wr, dim=-1, n=nsamples * oversample)\n",
    "\n",
    "data_db = 20*torch.log10(torch.abs(data)).detach()\n",
    "# Peak level as a Python float so the colour limits below are plain scalars, not a tensor on `device`\n",
    "m = torch.max(data_db).item()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(data_db.cpu().numpy(), origin=\"lower\", vmin=m-30, vmax=m, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range samples\")\n",
    "ax.set_ylabel(\"Azimuth samples\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddc7fc94-caff-4e75-a276-358fe4880fb5",
   "metadata": {},
   "source": [
    "Focused image without motion error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05e995c4-b67d-4a7e-a2a5-346e295034c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "r_res = 3e8 / (2 * bw * oversample) # Range bin size in input data\n",
    "\n",
    "img = torchbp.ops.backprojection_polar_2d(data, grid_polar, fc, r_res, pos, vel, att)\n",
    "img = img.squeeze(0) # Removes singular batch dimension\n",
    "# The backprojection image has its spectrum with DC at index zero.\n",
    "# Shifting the spectrum moves DC to the center bin, so that the solved phase\n",
    "# is in the same order as the position vector.\n",
    "# Without this shift, fftshift would have to be applied to the solved phase\n",
    "# to put it in the same order as the position vector.\n",
    "# The shift doesn't affect the absolute value of the image.\n",
    "img = torchbp.util.shift_spectrum(img)\n",
    "\n",
    "img_db = 20*torch.log10(torch.abs(img)).detach()\n",
    "\n",
    "# Peak level as a Python float: it is reused as the colour scale reference by the later plots,\n",
    "# and a plain scalar avoids handing matplotlib a tensor that lives on `device`.\n",
    "m = torch.max(img_db).item()\n",
    "\n",
    "# Axis limits for imshow: range on the x axis, sine of the azimuth angle on the y axis\n",
    "extent = [*grid_polar[\"r\"], *grid_polar[\"theta\"]]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range (m)\")\n",
    "ax.set_ylabel(\"Angle (sin radians)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29782eea-9576-4351-a630-41bf52bebd8d",
   "metadata": {},
   "source": [
    "Create corrupted image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f270d8e-634e-4fbf-9737-972367e1e608",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quadratic phase error 2*pi*x**2, with x running from -3 to 3 over the ntheta azimuth samples\n",
    "error_axis = torch.linspace(-3, 3, ntheta, dtype=torch.float32, device=device)\n",
    "phase_error = torch.exp(1j*2*torch.pi*error_axis[None,:]**2)\n",
    "\n",
    "# torch.angle returns the wrapped phase, so the plot shows the error modulo 2*pi\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(torch.angle(phase_error.squeeze()).cpu().numpy())\n",
    "ax.set_xlabel(\"Azimuth sample\")\n",
    "ax.set_ylabel(\"Phase error (radians)\")\n",
    "\n",
    "# Multiply the FFT of the image along the last axis by the phase error and transform back\n",
    "img_spectrum = torch.fft.fft(img, dim=-1)\n",
    "img_corrupted = torch.fft.ifft(img_spectrum * phase_error, dim=-1)\n",
    "\n",
    "img_corrupted_db = 20*torch.log10(torch.abs(img_corrupted))\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_corrupted_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range (m)\")\n",
    "ax.set_ylabel(\"Angle (sin radians)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b24ce554-3f32-4f89-bf82-f23fc561f85a",
   "metadata": {},
   "source": [
    "Phase gradient autofocus with phase difference estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5d7ba36-09fe-4885-b194-a1a013beae97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pga_pd returns the autofocused image together with the solved phase\n",
    "img_pga, phi = torchbp.autofocus.pga_pd(img_corrupted, remove_trend=False)\n",
    "\n",
    "img_pga_db = 20*torch.log10(torch.abs(img_pga))\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_pga_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range (m)\")\n",
    "ax.set_ylabel(\"Angle (sin radians)\")\n",
    "\n",
    "# Wrap the solved phase to its principal value before plotting\n",
    "phi_wrapped = torch.angle(torch.exp(1j*phi))\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(phi_wrapped.cpu().numpy())\n",
    "ax.set_xlabel(\"Azimuth samples\")\n",
    "ax.set_ylabel(\"Phase error (radians)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1036ccb-2383-4f41-be81-8ae96366a314",
   "metadata": {},
   "source": [
    "Apply maximum likelihood phase gradient autofocus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2170e6dc-c32e-4a7a-8ee5-0b9e3e1e6ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pga_ml returns the autofocused image together with the solved phase\n",
    "img_pga, phi = torchbp.autofocus.pga_ml(img_corrupted, remove_trend=False)\n",
    "\n",
    "img_pga_db = 20*torch.log10(torch.abs(img_pga))\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_pga_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range (m)\")\n",
    "ax.set_ylabel(\"Angle (sin radians)\")\n",
    "\n",
    "# Wrap the solved phase to its principal value before plotting\n",
    "phi_wrapped = torch.angle(torch.exp(1j*phi))\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(phi_wrapped.cpu().numpy())\n",
    "ax.set_xlabel(\"Azimuth samples\")\n",
    "ax.set_ylabel(\"Phase error (radians)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "055e8f56-d7d2-4225-a4cd-f302de7ba3c2",
   "metadata": {},
   "source": [
    "Multiplying the FFT of the corrupted image by the conjugate of the solved phase term, `exp(-1j*phi)`, and taking the inverse FFT gives the focused image. This should be identical to the image returned by `pga_ml`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45663e08-5bbf-42ca-9599-962df0dbc6d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the solved phase: multiply the FFT along the last axis by exp(-1j*phi) and transform back\n",
    "corrupted_spectrum = torch.fft.fft(img_corrupted, dim=-1)\n",
    "img_focused = torch.fft.ifft(corrupted_spectrum * torch.exp(-1j*phi), dim=-1)\n",
    "\n",
    "img_focused_db = 20*torch.log10(torch.abs(img_focused))\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_focused_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n",
    "ax.set_xlabel(\"Range (m)\")\n",
    "ax.set_ylabel(\"Angle (sin radians)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceddc632-0865-4369-9fdb-9abe2af8c5c6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
